{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Demo 0403: The Chain Rule for Matrix Derivatives**"
   ]
  },
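  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **Problem statement:**\n",
    "Machine learning algorithms often need derivatives with respect to matrices. For example, let $X$ be an $(N,D)$ matrix, $W$ a $(D,M)$ matrix, and $Y=XW$ the resulting $(N,M)$ matrix.  \n",
    "Let the scalar $L=\\sum_{i=1}^{N}\\sum_{j=1}^{M}f(y_{ij}) $, where $y_{ij}$ is the $(i,j)$ entry of $Y$ and $f(y_{ij})$ is some differentiable function of $y_{ij}$. In other words, $L$ is the sum of the values $f(y_{ij})$ over all entries of $Y$.  \n",
    "Goal: compute $\\dfrac{\\partial L}{\\partial X}$ and $\\dfrac{\\partial L}{\\partial W}$.  \n",
    "Notes:\n",
    "* $X$ has shape $(N,D)$ and $L$ is a scalar, so $\\dfrac{\\partial L}{\\partial X}$ is the derivative of $L$ with respect to each entry of $X$, arranged as an $(N,D)$ matrix.\n",
    "* $W$ has shape $(D,M)$ and $L$ is a scalar, so $\\dfrac{\\partial L}{\\partial W}$ is a $(D,M)$ matrix.\n",
    "* $L$ is not computed directly from $X$ and $W$ but through the intermediate variable $Y$. By the chain rule, obtaining $\\dfrac{\\partial L}{\\partial X}$ requires both $\\dfrac{\\partial L}{\\partial Y}$ and $\\dfrac{\\partial Y}{\\partial X}$. The same holds for $\\dfrac{\\partial L}{\\partial W}$.\n",
    "* $Y$ has shape $(N,M)$ and $X$ has shape $(N,D)$. Computing $\\dfrac{\\partial Y}{\\partial X}$ means differentiating every entry of $Y$ with respect to every entry of $X$, which gives $N \\times M \\times N \\times D$ derivative values. Together they form the **Jacobian matrix** (of shape $(NM) \\times (ND)$ once $Y$ and $X$ are flattened).\n",
    "* In machine learning, $N$, $D$ and $M$ can be large. For example, with $N=64,D=4096,M=4096$ the Jacobian has $N \\times M \\times N \\times D=64 \\times 1024 \\times 1024 \\times 1024=2^{36}$ entries. Stored as 32-bit floats (4 bytes each), that takes $ 4 \\times 64 \\times 1024 \\times 1024 \\times 1024 $ bytes, i.e. 256 GB of memory, which is impractical.\n",
    "* A method is therefore needed that yields the partial derivatives without computing the Jacobian. Because $Y$ is obtained from $X$ and $W$ by a linear transformation, such a simplification is possible."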
   ]
  },
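  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The memory estimate above can be reproduced with a few lines of arithmetic. This is only a sketch: the variable names are illustrative and not part of the original notes, and dense float32 storage (4 bytes per element) is assumed, as in the text.\n",
    "\n",
    "```python\n",
    "# Sizes taken from the example in the text.\n",
    "N, D, M = 64, 4096, 4096\n",
    "BYTES_PER_FLOAT32 = 4  # assumption: dense float32 storage\n",
    "\n",
    "# dY/dX holds one derivative for every (entry of Y, entry of X) pair.\n",
    "jacobian_elements = (N * M) * (N * D)\n",
    "jacobian_bytes = jacobian_elements * BYTES_PER_FLOAT32\n",
    "\n",
    "# For comparison: dL/dY (N x M) plus dL/dX (N x D), both plain matrices.\n",
    "shortcut_bytes = (N * M + N * D) * BYTES_PER_FLOAT32\n",
    "\n",
    "print(f'Jacobian elements: {jacobian_elements}')\n",
    "print(f'Jacobian memory:   {jacobian_bytes / 2**30:.0f} GiB')\n",
    "print(f'dL/dY and dL/dX:   {shortcut_bytes / 2**20:.0f} MiB')\n",
    "```\n",
    "\n",
    "The last line is included only for scale: the two gradient matrices that are actually needed are several orders of magnitude smaller than the full Jacobian."
   ]
  },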
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **Experiment 1: Deriving the matrix chain-rule formulas**\n",
    "Let $X$ be an $(N,D)$ matrix, $W$ a $(D,M)$ matrix, and $Y=XW$ an $(N,M)$ matrix, and let $L$ be a scalar function of $Y$. Then:\n",
    "* $ \\dfrac{\\partial L}{\\partial X}=\\dfrac{\\partial L}{\\partial Y}W^T $\n",
    "* $ \\dfrac{\\partial L}{\\partial W}=X^T\\dfrac{\\partial L}{\\partial Y} $\n",
    "\n",
    "A variant of the same result: let $W$ be an $(N,D)$ matrix, $X$ a $(D,M)$ matrix, and $Y=WX$ an $(N,M)$ matrix, with $L$ again a scalar function of $Y$. Then:\n",
    "* $ \\dfrac{\\partial L}{\\partial W}=\\dfrac{\\partial L}{\\partial Y}X^T $\n",
    "* $ \\dfrac{\\partial L}{\\partial X}=W^T\\dfrac{\\partial L}{\\partial Y} $\n",
    "\n",
    "The steps below walk through a derivation of the first pair of formulas. To keep the expressions manageable, only the specific case $N=2, D=4, M=3$ is considered:"
   ]
  },
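  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before deriving the formulas, they can be sanity-checked numerically. The sketch below assumes NumPy and picks $f(y)=y^2$ purely for illustration (the original notes leave $f$ unspecified); `loss` and `numerical_grad` are hypothetical helper names. It compares the two closed-form gradients for $Y=XW$ against central finite differences.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "N, D, M = 2, 4, 3\n",
    "X = rng.standard_normal((N, D))\n",
    "W = rng.standard_normal((D, M))\n",
    "\n",
    "def loss(X, W):\n",
    "    # Assumption for illustration: f(y) = y**2, so L sums the squared entries of Y.\n",
    "    return np.sum((X @ W) ** 2)\n",
    "\n",
    "# For f(y) = y**2 the upstream gradient dL/dY is 2 * Y.\n",
    "dL_dY = 2 * (X @ W)\n",
    "dL_dX = dL_dY @ W.T  # shape (N, D)\n",
    "dL_dW = X.T @ dL_dY  # shape (D, M)\n",
    "\n",
    "def numerical_grad(fn, A, eps=1e-6):\n",
    "    # Central differences, perturbing one entry of A at a time.\n",
    "    grad = np.zeros_like(A)\n",
    "    for idx in np.ndindex(*A.shape):\n",
    "        step = np.zeros_like(A)\n",
    "        step[idx] = eps\n",
    "        grad[idx] = (fn(A + step) - fn(A - step)) / (2 * eps)\n",
    "    return grad\n",
    "\n",
    "num_dX = numerical_grad(lambda A: loss(A, W), X)\n",
    "num_dW = numerical_grad(lambda A: loss(X, A), W)\n",
    "\n",
    "print(np.allclose(dL_dX, num_dX, atol=1e-5))\n",
    "print(np.allclose(dL_dW, num_dW, atol=1e-5))\n",
    "```\n",
    "\n",
    "Agreement on random inputs is not a proof, but it is a cheap way to catch a misplaced transpose or a swapped operand order."
   ]
  },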
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> **Step 1: Set up the example**  \n",
    "$ X=\\left(\\begin{matrix}\n",
    "x_{11} & x_{12} & x_{13} & x_{14} \\\\\n",
    "x_{21} & x_{22} & x_{23} & x_{24}\n",
    "\\end{matrix}\\right) $  \n",
    "$ W=\\left(\\begin{matrix}\n",
    "w_{11} & w_{12} & w_{13} \\\\\n",
    "w_{21} & w_{22} & w_{23} \\\\\n",
    "w_{31} & w_{32} & w_{33} \\\\\n",
    "w_{41} & w_{42} & w_{43}\n",
    "\\end{matrix}\\right) $  \n",
    "$ \\begin{aligned}\n",
    "Y & = XW \\\\\n",
    "& = \\left(\\begin{matrix}\n",
    "x_{11}w_{11}+x_{12}w_{21}+x_{13}w_{31}+x_{14}w_{41} &\n",
    "x_{11}w_{12}+x_{12}w_{22}+x_{13}w_{32}+x_{14}w_{42} &\n",
    "x_{11}w_{13}+x_{12}w_{23}+x_{13}w_{33}+x_{14}w_{43} \\\\\n",
    "x_{21}w_{11}+x_{22}w_{21}+x_{23}w_{31}+x_{24}w_{41} &\n",
    "x_{21}w_{12}+x_{22}w_{22}+x_{23}w_{32}+x_{24}w_{42} &\n",
    "x_{21}w_{13}+x_{22}w_{23}+x_{23}w_{33}+x_{24}w_{43}\n",
    "\\end{matrix}\\right) \\\\\n",
    "& = \\left(\\begin{matrix}\n",
    "y_{11} & y_{12} & y_{13} \\\\\n",
    "y_{21} & y_{22} & y_{23}\n",
    "\\end{matrix}\\right)\n",
    "\\end{aligned} $  \n",
    "$ L=f(y_{11})+f(y_{12})+f(y_{13})+f(y_{21})+f(y_{22})+f(y_{23})=\\sum_{i,j}f(y_{ij}) $  \n",
    "Here $f(y_{ij})$ is a function of $y_{ij}$, and the objective $L$ is the sum of these function values.  \n",
    "Compute $\\dfrac{\\partial L}{\\partial W}$ and $\\dfrac{\\partial L}{\\partial X}$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">**Step 2: Examine $\\dfrac{\\partial L}{\\partial Y}$**  \n",
    "Recall that $L=f(y_{11})+f(y_{12})+f(y_{13})+f(y_{21})+f(y_{22})+f(y_{23})$  \n",
    "Treating each $f(y_{ij})$ as an intermediate variable, it follows that $\\dfrac{\\partial L}{\\partial f(y_{ij})}=1 $  \n",
    "$L$ is a scalar and $Y$ is a $(2,3)$ matrix, so the derivative is also a $(2,3)$ matrix:  \n",
    "$\\begin{aligned}\n",
    "\\dfrac{\\partial L}{\\partial Y} & =\n",
    "\\left(\\begin{matrix}\n",
    "\\dfrac{\\partial L}{\\partial y_{11}} &\n",
    "\\dfrac{\\partial L}{\\partial y_{12}} &\n",
    "\\dfrac{\\partial L}{\\partial y_{13}} \\\\\n",
    "\\dfrac{\\partial L}{\\partial y_{21}} &\n",
    "\\dfrac{\\partial L}{\\partial y_{22}} &\n",
    "\\dfrac{\\partial L}{\\partial y_{23}}\n",
    "\\end{matrix}\\right) \\\\ \\\\\n",
    "& =\\left(\\begin{matrix}\n",
    "\\dfrac{\\partial L}{\\partial f(y_{11})} \\cdot \\dfrac{\\partial f(y_{11})}{\\partial y_{11}} &\n",
    "\\dfrac{\\partial L}{\\partial f(y_{12})} \\cdot \\dfrac{\\partial f(y_{12})}{\\partial y_{12}} &\n",
    "\\dfrac{\\partial L}{\\partial f(y_{13})} \\cdot \\dfrac{\\partial f(y_{13})}{\\partial y_{13}} \\\\\n",
    "\\dfrac{\\partial L}{\\partial f(y_{21})} \\cdot \\dfrac{\\partial f(y_{21})}{\\partial y_{21}} &\n",
    "\\dfrac{\\partial L}{\\partial f(y_{22})} \\cdot \\dfrac{\\partial f(y_{22})}{\\partial y_{22}} &\n",
    "\\dfrac{\\partial L}{\\partial f(y_{23})} \\cdot \\dfrac{\\partial f(y_{23})}{\\partial y_{23}}\n",
    "\\end{matrix}\\right) \\\\ \\\\\n",
    "& =\\left(\\begin{matrix}\n",
    "\\dfrac{\\partial f(y_{11})}{\\partial y_{11}} &\n",
    "\\dfrac{\\partial f(y_{12})}{\\partial y_{12}} &\n",
    "\\dfrac{\\partial f(y_{13})}{\\partial y_{13}} \\\\\n",
    "\\dfrac{\\partial f(y_{21})}{\\partial y_{21}} &\n",
    "\\dfrac{\\partial f(y_{22})}{\\partial y_{22}} &\n",
    "\\dfrac{\\partial f(y_{23})}{\\partial y_{23}}\n",
    "\\end{matrix}\\right)\n",
    "\\end{aligned}$"
   ]
  },
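  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a concrete illustration of the result in step 2, suppose $f(y)=y^2$. This choice is an assumption made only for this example; the notes leave $f$ general. Each entry is then $\\dfrac{\\partial f(y_{ij})}{\\partial y_{ij}}=2y_{ij}$, so:  \n",
    "$ \\dfrac{\\partial L}{\\partial Y}=\\left(\\begin{matrix}\n",
    "2y_{11} & 2y_{12} & 2y_{13} \\\\\n",
    "2y_{21} & 2y_{22} & 2y_{23}\n",
    "\\end{matrix}\\right)=2Y $  \n",
    "More generally, $\\dfrac{\\partial L}{\\partial Y}$ is the derivative $f'$ applied to $Y$ entry by entry, so it always has the same shape as $Y$."
   ]
  },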
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">**Step 3: Examine $\\dfrac{\\partial L}{\\partial x_{11}}$**  \n",
    "$ \\begin{aligned} \\\\\n",
    "\\begin{aligned}\n",
    "L = & f(y_{11})+f(y_{12})+f(y_{13})+f(y_{21})+f(y_{22})+f(y_{23}) \\\\\n",
    "= & f(x_{11}w_{11}+x_{12}w_{21}+x_{13}w_{31}+x_{14}w_{41}) +\n",
    "f(x_{11}w_{12}+x_{12}w_{22}+x_{13}w_{32}+x_{14}w_{42}) + \\\\\n",
    "& f(x_{11}w_{13}+x_{12}w_{23}+x_{13}w_{33}+x_{14}w_{43}) +\n",
    "f(x_{21}w_{11}+x_{22}w_{21}+x_{23}w_{31}+x_{24}w_{41}) + \\\\\n",
    "& f(x_{21}w_{12}+x_{22}w_{22}+x_{23}w_{32}+x_{24}w_{42}) +\n",
    "f(x_{21}w_{13}+x_{22}w_{23}+x_{23}w_{33}+x_{24}w_{43})\n",
    "\\end{aligned} \\end{aligned}$  \n",
    "$ \\begin{aligned} \\\\\n",
    "\\dfrac{\\partial L}{\\partial x_{11}} = &\n",
    "\\dfrac{\\partial{(f(y_{11})+f(y_{12})+f(y_{13})+f(y_{21})+f(y_{22})+f(y_{23}))}}{\\partial x_{11}} \\\\\n",
    "= & \\dfrac{\\partial{f(y_{11})}}{\\partial x_{11}} + \\dfrac{\\partial{f(y_{12})}}{\\partial x_{11}} + \\dfrac{\\partial{f(y_{13})}}{\\partial x_{11}} + \\dfrac{\\partial{f(y_{21})}}{\\partial x_{11}} + \\dfrac{\\partial{f(y_{22})}}{\\partial x_{11}} + \\dfrac{\\partial{f(y_{23})}}{\\partial x_{11}}\n",
    "\\end{aligned}$  \n",
    "By the chain rule (note that $y_{11}$ and $x_{11}$ are both scalars here, so the ordinary scalar chain rule applies):  \n",
    "$ \\dfrac{\\partial f(y_{11})}{\\partial x_{11}}=\\dfrac{\\partial f(y_{11})}{\\partial y_{11}} \\cdot \\dfrac{\\partial y_{11}}{\\partial x_{11}}=\\dfrac{\\partial L}{\\partial y_{11}} \\cdot \\dfrac{\\partial y_{11}}{\\partial x_{11}} $ (the second equality uses $\\dfrac{\\partial L}{\\partial f(y_{ij})}=1 $, as in step 2)  \n",
    "Therefore, since only $y_{11}$, $y_{12}$ and $y_{13}$ depend on $x_{11}$:  \n",
    "$\\begin{aligned}\n",
    "\\dfrac{\\partial L}{\\partial x_{11}} = & \\dfrac{\\partial L}{\\partial y_{11}} \\cdot \\dfrac{\\partial y_{11}}{\\partial x_{11}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{12}} \\cdot \\dfrac{\\partial y_{12}}{\\partial x_{11}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{13}} \\cdot \\dfrac{\\partial y_{13}}{\\partial x_{11}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{21}} \\cdot \\dfrac{\\partial y_{21}}{\\partial x_{11}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{22}} \\cdot \\dfrac{\\partial y_{22}}{\\partial x_{11}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{23}} \\cdot \\dfrac{\\partial y_{23}}{\\partial x_{11}} \\\\\n",
    "= & \\dfrac{\\partial L}{\\partial y_{11}} \\cdot w_{11} +\n",
    "\\dfrac{\\partial L}{\\partial y_{12}} \\cdot w_{12} +\n",
    "\\dfrac{\\partial L}{\\partial y_{13}} \\cdot w_{13}\n",
    "\\end{aligned}$  \n",
    "Similarly:  \n",
    "$ \\begin{aligned}\n",
    "\\dfrac{\\partial L}{\\partial x_{12}} = & \\dfrac{\\partial L}{\\partial y_{11}} \\cdot \\dfrac{\\partial y_{11}}{\\partial x_{12}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{12}} \\cdot \\dfrac{\\partial y_{12}}{\\partial x_{12}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{13}} \\cdot \\dfrac{\\partial y_{13}}{\\partial x_{12}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{21}} \\cdot \\dfrac{\\partial y_{21}}{\\partial x_{12}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{22}} \\cdot \\dfrac{\\partial y_{22}}{\\partial x_{12}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{23}} \\cdot \\dfrac{\\partial y_{23}}{\\partial x_{12}} \\\\ = & \\dfrac{\\partial L}{\\partial y_{11}} \\cdot w_{21} +\n",
    "\\dfrac{\\partial L}{\\partial y_{12}} \\cdot w_{22} +\n",
    "\\dfrac{\\partial L}{\\partial y_{13}} \\cdot w_{23}\n",
    "\\end{aligned}$  \n",
    "$\\begin{aligned}\n",
    "\\dfrac{\\partial L}{\\partial x_{21}} = & \\dfrac{\\partial L}{\\partial y_{11}} \\cdot \\dfrac{\\partial y_{11}}{\\partial x_{21}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{12}} \\cdot \\dfrac{\\partial y_{12}}{\\partial x_{21}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{13}} \\cdot \\dfrac{\\partial y_{13}}{\\partial x_{21}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{21}} \\cdot \\dfrac{\\partial y_{21}}{\\partial x_{21}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{22}} \\cdot \\dfrac{\\partial y_{22}}{\\partial x_{21}} +\n",
    "\\dfrac{\\partial L}{\\partial y_{23}} \\cdot \\dfrac{\\partial y_{23}}{\\partial x_{21}} \\\\ = & \\dfrac{\\partial L}{\\partial y_{21}} \\cdot w_{11} +\n",
    "\\dfrac{\\partial L}{\\partial y_{22}} \\cdot w_{12} +\n",
    "\\dfrac{\\partial L}{\\partial y_{23}} \\cdot w_{13}\n",
    "\\end{aligned}$  \n",
    "and so on for the remaining entries of $X$."
   ]
  },
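  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the small sizes used here, the full Jacobian can be built explicitly, which makes it possible to check the element-wise result of step 3 numerically. The sketch below assumes NumPy; the array names are illustrative, and `G` is a random stand-in for $\\dfrac{\\partial L}{\\partial Y}$ rather than the gradient of any particular $f$.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "N, D, M = 2, 4, 3\n",
    "W = rng.standard_normal((D, M))\n",
    "G = rng.standard_normal((N, M))  # hypothetical stand-in for dL/dY\n",
    "\n",
    "# J[i, j, a, b] = d y_ij / d x_ab, which is W[b, j] when i == a and 0 otherwise.\n",
    "J = np.einsum('ia,bj->ijab', np.eye(N), W)\n",
    "\n",
    "# Chain rule written out in full: sum over every entry y_ij.\n",
    "dL_dX_full = np.einsum('ij,ijab->ab', G, J)\n",
    "\n",
    "# The same quantity computed without the Jacobian.\n",
    "dL_dX_short = G @ W.T\n",
    "\n",
    "# Entry (1,1) as written out by hand in step 3 (0-based indices in code).\n",
    "manual_x11 = G[0, 0] * W[0, 0] + G[0, 1] * W[0, 1] + G[0, 2] * W[0, 2]\n",
    "\n",
    "print(J.shape)\n",
    "print(np.count_nonzero(J), 'of', J.size, 'Jacobian entries are nonzero')\n",
    "print(np.allclose(dL_dX_full, dL_dX_short))\n",
    "print(np.isclose(manual_x11, dL_dX_short[0, 0]))\n",
    "```\n",
    "\n",
    "Entries of `J` with `i != a` are zero by construction, which is the same observation that removes the $y_{21}$, $y_{22}$ and $y_{23}$ terms from $\\dfrac{\\partial L}{\\partial x_{11}}$."
   ]
  },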
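  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">**Step 4: Write the result in matrix form**  \n",
    "In the expanded matrix below, each of the two rows is wrapped across two lines, with $\\to$ marking the continuation.  \n",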
    "$\\begin{aligned}\n",
    "\\dfrac{\\partial L}{\\partial X} \\\\ = & \\left(\\begin{matrix}\n",
    "\\dfrac{\\partial L}{\\partial x_{11}}&\\dfrac{\\partial L}{\\partial x_{12}}&\n",
    "\\dfrac{\\partial L}{\\partial x_{13}}&\\dfrac{\\partial L}{\\partial x_{14}} \\\\\n",
    "\\dfrac{\\partial L}{\\partial x_{21}}&\\dfrac{\\partial L}{\\partial x_{22}}&\n",
    "\\dfrac{\\partial L}{\\partial x_{23}}&\\dfrac{\\partial L}{\\partial x_{24}}\n",
    "\\end{matrix}\\right) \\\\ \\\\ = & \\left(\\begin{matrix}\n",
    "\\dfrac{\\partial L}{\\partial y_{11}} \\cdot w_{11} +\n",
    "\\dfrac{\\partial L}{\\partial y_{12}} \\cdot w_{12} +\n",
    "\\dfrac{\\partial L}{\\partial y_{13}} \\cdot w_{13} &\n",
    "\\dfrac{\\partial L}{\\partial y_{11}} \\cdot w_{21} +\n",
    "\\dfrac{\\partial L}{\\partial y_{12}} \\cdot w_{22} +\n",
    "\\dfrac{\\partial L}{\\partial y_{13}} \\cdot w_{23} \\to \\\\\n",
    "\\to \\dfrac{\\partial L}{\\partial y_{11}} \\cdot w_{31} +\n",
    "\\dfrac{\\partial L}{\\partial y_{12}} \\cdot w_{32} +\n",
    "\\dfrac{\\partial L}{\\partial y_{13}} \\cdot w_{33} &\n",
    "\\dfrac{\\partial L}{\\partial y_{11}} \\cdot w_{41} +\n",
    "\\dfrac{\\partial L}{\\partial y_{12}} \\cdot w_{42} +\n",
    "\\dfrac{\\partial L}{\\partial y_{13}} \\cdot w_{43} \\\\\n",
    "\\dfrac{\\partial L}{\\partial y_{21}} \\cdot w_{11} +\n",
    "\\dfrac{\\partial L}{\\partial y_{22}} \\cdot w_{12} +\n",
    "\\dfrac{\\partial L}{\\partial y_{23}} \\cdot w_{13} &\n",
    "\\dfrac{\\partial L}{\\partial y_{21}} \\cdot w_{21} +\n",
    "\\dfrac{\\partial L}{\\partial y_{22}} \\cdot w_{22} +\n",
    "\\dfrac{\\partial L}{\\partial y_{23}} \\cdot w_{23} \\to \\\\\n",
    "\\to \\dfrac{\\partial L}{\\partial y_{21}} \\cdot w_{31} +\n",
    "\\dfrac{\\partial L}{\\partial y_{22}} \\cdot w_{32} +\n",
    "\\dfrac{\\partial L}{\\partial y_{23}} \\cdot w_{33} &\n",
    "\\dfrac{\\partial L}{\\partial y_{21}} \\cdot w_{41} +\n",
    "\\dfrac{\\partial L}{\\partial y_{22}} \\cdot w_{42} +\n",
    "\\dfrac{\\partial L}{\\partial y_{23}} \\cdot w_{43}\n",
    "\\end{matrix}\\right) \\\\ \\\\ = & \\left(\\begin{matrix}\n",
    "\\dfrac{\\partial L}{\\partial y_{11}} &\n",
    "\\dfrac{\\partial L}{\\partial y_{12}} &\n",
    "\\dfrac{\\partial L}{\\partial y_{13}} \\\\\n",
    "\\dfrac{\\partial L}{\\partial y_{21}} &\n",
    "\\dfrac{\\partial L}{\\partial y_{22}} &\n",
    "\\dfrac{\\partial L}{\\partial y_{23}}\n",
    "\\end{matrix}\\right) \\cdot\n",
    "\\left(\\begin{matrix}\n",
    "w_{11} & w_{21} & w_{31} & w_{41} \\\\\n",
    "w_{12} & w_{22} & w_{32} & w_{42} \\\\\n",
    "w_{13} & w_{23} & w_{33} & w_{43}\n",
    "\\end{matrix}\\right) \\\\ \\\\ = & \\dfrac{\\partial L}{\\partial Y}W^T\n",
    "\\end{aligned}$"
   ]
  },
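  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same pattern can be written for arbitrary sizes using index notation. This is a sketch of the general argument; the Kronecker delta $\\delta_{ia}$ (equal to 1 if $i=a$ and 0 otherwise) is notation introduced here and is not used elsewhere in these notes.  \n",
    "Since $y_{ij}=\\sum_{k=1}^{D}x_{ik}w_{kj}$:  \n",
    "$ \\dfrac{\\partial y_{ij}}{\\partial x_{ab}}=\\delta_{ia}w_{bj}, \\qquad \\dfrac{\\partial y_{ij}}{\\partial w_{ab}}=x_{ia}\\delta_{jb} $  \n",
    "Substituting into the scalar chain rule $\\dfrac{\\partial L}{\\partial x_{ab}}=\\sum_{i,j}\\dfrac{\\partial L}{\\partial y_{ij}} \\cdot \\dfrac{\\partial y_{ij}}{\\partial x_{ab}}$ (and likewise for $w_{ab}$):  \n",
    "$ \\begin{aligned}\n",
    "\\dfrac{\\partial L}{\\partial x_{ab}} & = \\sum_{i=1}^{N}\\sum_{j=1}^{M}\\dfrac{\\partial L}{\\partial y_{ij}} \\cdot \\delta_{ia}w_{bj}\n",
    "= \\sum_{j=1}^{M}\\dfrac{\\partial L}{\\partial y_{aj}} \\cdot w_{bj}\n",
    "= \\left(\\dfrac{\\partial L}{\\partial Y}W^T\\right)_{ab} \\\\ \\\\\n",
    "\\dfrac{\\partial L}{\\partial w_{ab}} & = \\sum_{i=1}^{N}\\sum_{j=1}^{M}\\dfrac{\\partial L}{\\partial y_{ij}} \\cdot x_{ia}\\delta_{jb}\n",
    "= \\sum_{i=1}^{N}x_{ia} \\cdot \\dfrac{\\partial L}{\\partial y_{ib}}\n",
    "= \\left(X^T\\dfrac{\\partial L}{\\partial Y}\\right)_{ab}\n",
    "\\end{aligned} $  \n",
    "Nothing in these two lines depends on the particular values of $N$, $D$ or $M$."
   ]
  },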
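  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">**Step 5: Conclusion**  \n",
    "* By the same argument: $ \\dfrac{\\partial L}{\\partial W}=X^T\\dfrac{\\partial L}{\\partial Y} $\n",
    "* With these formulas, the partial derivatives can be computed conveniently without ever forming the Jacobian. Although they were derived here for $N=2, D=4, M=3$, the element-wise argument does not rely on these particular sizes, so the formulas hold for general $N$, $D$ and $M$."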
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
